{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<sub>&copy; 2021 Neuralmagic, Inc. // [Neural Magic Legal](https://neuralmagic.com/legal)</sub> \n",
    "\n",
    "# Keras Classification Model Pruning Using SparseML\n",
    "\n",
    "This notebook provides a step-by-step walkthrough for pruning an already trained (dense) model to enable better performance at inference time using the DeepSparse Inference Engine. You will:\n",
    "- Set up the model and dataset\n",
    "- Integrate the Keras training flow with SparseML\n",
    "- Prune the model using the Keras+SparseML flow\n",
    "- Export to [ONNX](https://onnx.ai/)\n",
    "\n",
    "Reading through this notebook will be reasonably quick to gain an intuition for how to plug SparseML into your Keras training flow.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Requirements\n",
    "To run this notebook, you will need the following packages already installed:\n",
    "* SparseML and SparseZoo;\n",
    "* Tensorflow >=2.1, which includes Keras and TensorBoard;\n",
    "* keras2onnx.\n",
    "\n",
    "You can install any package that is not already present via `pip`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Setting Up the Model and Dataset\n",
    "\n",
    "In this notebook, you will prune a simple convolution neural network model trained on the MNIST dataset. The pretrained model's architecture and weights are downloaded from the SparseZoo model repo. The dataset is downloaded directly from  Keras datasets library."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Up the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell defines a procedure to download a model from the SparseZoo; additionally, for convenience it also returns the path to an optimization recipe. You construct a Keras model instance from the pretrained to prune in a later step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from sparseml.keras.utils import keras\n",
    "from sparsezoo import Model\n",
    "\n",
    "# Root directory for the notebook artifacts\n",
    "root_dir = \"./notebooks/keras\"\n",
    "\n",
    "def download_model_and_recipe(root_dir: str):\n",
    "    \"\"\"\n",
    "    Download pretrained model and a pruning recipe\n",
    "    \"\"\"\n",
    "    model_dir = os.path.join(root_dir, \"mnist\")\n",
    "    zoo_model = Model(model_dir)\n",
    "\n",
    "\n",
    "    model_file_path = zoo_model.training.default.get_file(\"model.h5\").path\n",
    "    if not os.path.exists(model_file_path) or not model_file_path.endswith(\".h5\"):\n",
    "        raise RuntimeError(\"Model file not found: {}\".format(model_file_path))\n",
    "    recipe_file_path = zoo_model.recipes.default.path\n",
    "    if not os.path.exists(recipe_file_path):\n",
    "        raise RuntimeError(\"Recipe file not found: {}\".format(recipe_file_path))\n",
    "    return model_file_path, recipe_file_path\n",
    "\n",
    "model_file_path, recipe_file_path = download_model_and_recipe(root_dir)\n",
    "\n",
    "print(\"Loading model {}\".format(model_file_path))\n",
    "model = keras.models.load_model(model_file_path)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Up the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will download the MNIST dataset from Keras datasets library as follows. You will also normalize the data before using it for training and evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "\n",
    "# Number of classes\n",
    "num_classes = 10\n",
    "\n",
    "# Load MNIST dataset\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "\n",
    "# Normalize data\n",
    "x_train = x_train.astype('float32') / 255\n",
    "x_test = x_test.astype('float32') / 255\n",
    "\n",
    "# Add batch dimension (for older TF versions)\n",
    "x_train = numpy.expand_dims(x_train, -1)\n",
    "x_test = numpy.expand_dims(x_test, -1)\n",
    "\n",
    "y_train = keras.utils.to_categorical(y_train, num_classes)\n",
    "y_test = keras.utils.to_categorical(y_test, num_classes)\n",
    "\n",
    "print(\"Dataset loaded and normalized\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before pruning the model, you could run the cell below to verify the accuracy of the model on the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = model.evaluate(x_test, y_test)\n",
    "print(\"Test loss, accuracy: \", res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Pruning the Pretrained Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, you will prune the above pretrained Keras model using the SparseML model optimization library. Recall that a common training workflow in Keras is first to compile the model with the appropriate losses, metrics and an optimizer, then to train the model using the `fit()` method of the `Model` class. The SparseML library makes it easy to extend this training workflow to perform gradual pruning based on weight magnitudes.\n",
    "\n",
    "Given a pretrained model, the pruning workflow can be summarized as follows:\n",
    "1. Create a recipe for pruning, which could be done effectively using the Sparsify toolkit\n",
    "2. Instantiate a Keras optimizer instance (such as SGD or Adam)\n",
    "3. Instantiate a `ScheduledModifierManager` object from the recipe\n",
    "4. Enhance the model and optimizer with pruning data structures by calling the manager's `modify` method. At this step, you have options to define the loggers used during the pruning process. The results of this step are a model to be pruned, an optimizer that should be used and a list of callbacks\n",
    "5. [Optional] Add to the callback list any additional callbacks such as model checkpoint and the SparseML built-in LossesAndMetricsLogging callback\n",
    "6. Compile and fit the modified model using Keras built-in APIs, using the optimizer and callback list\n",
    "7. Erase the pruning information in the enhanced model, and get back the original model with pruned weights\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, you will set up a directory path for logging and the frequency for the logging update."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "\n",
    "# Logging directory\n",
    "log_dir = \"./tensorboard/mnist:\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "print(\"Logging directory: {}\".format(log_dir))\n",
    "\n",
    "# Number of steps before the next logging should take place\n",
    "# Use \"epoch\" or \"batch\" to log at every training epoch or batch (respectively)\n",
    "update_freq = 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell contains the few steps required for pruning using the SparseML library, ultimately resulting in a modified model, optimizer and a list of callbacks incoporating the optimization logics. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "import math\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "\n",
    "from sparseml.keras.optim import ScheduledModifierManager\n",
    "from sparseml.keras.utils import LossesAndMetricsLoggingCallback, TensorBoardLogger\n",
    "\n",
    "# Training batch size\n",
    "batch_size = 32\n",
    "\n",
    "# Number of steps per epoch\n",
    "steps_per_epoch = math.ceil(len(x_train) / batch_size)\n",
    "\n",
    "# Create a manager from the recipe\n",
    "manager = ScheduledModifierManager.from_yaml(recipe_file_path)\n",
    "\n",
    "# Create an optimizer\n",
    "optimizer = Adam()\n",
    "\n",
    "# Optional: Create a TensorBoardLogger instance\n",
    "loggers = TensorBoardLogger(log_dir=log_dir, update_freq=update_freq)\n",
    "\n",
    "# Modify the model and optimizer\n",
    "model_for_pruning, optimizer, callbacks = manager.modify(model, optimizer, steps_per_epoch, loggers=loggers)\n",
    "\n",
    "# Include your own callbacks. Here you will use the built-in LossesAndMetricsLoggingCallback\n",
    "callbacks.append(LossesAndMetricsLoggingCallback(loggers))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you are ready to compile the modified model using the losses and metrics of your choice, and the optimizer enhanced by the mananger. The last step is to train the model using its `fit()` method, passing in (among other information) the list of callbacks constructed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the modified model\n",
    "model_for_pruning.compile(\n",
    "    loss=keras.losses.categorical_crossentropy,\n",
    "    optimizer=optimizer,\n",
    "    metrics=['accuracy'],\n",
    "    run_eagerly=True\n",
    ")\n",
    "\n",
    "# Prune the model\n",
    "model_for_pruning.fit(x_train, y_train, batch_size=batch_size, epochs=manager.max_epochs,\n",
    "                      validation_data=(x_test, y_test), shuffle=True, callbacks=callbacks)\n",
    "\n",
    "print(\"Pruning finished\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is time to verify the accuracy of the pruned model. Later on you can also check the sparsity level of the ONNX version of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify the pruned model's accuracy\n",
    "res = model_for_pruning.evaluate(x_test, y_test)\n",
    "print(\"Validation loss, accuracy: \", res)\n",
    "\n",
    "# Erase the enhanced information used for pruning \n",
    "pruned_model = manager.finalize(model_for_pruning)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given that you used a TensorBoardLogger, you can now observe the logging information in TensorBoard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext tensorboard\n",
    "%tensorboard --logdir ./tensorboard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Examine the Pruned Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can observe the layers of the pruned Keras model using its `layers` property and `get_weights()` method. First, print out the list of layers of this model, and recall that in this optimization recipe we pruned all the Conv2D layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pruned_model.layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can observe the weights of any Conv2D that you pruned, and notice that the majority of them are zeros."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change the layer index to examine the layers. Choose one among 1, 4, 7, 10 as the layer\n",
    "# indices of the pruned layers.\n",
    "layer_index = 10\n",
    "pruned_model.layers[layer_index].get_weights()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To observe the overall sparsity of the model, as well as the sparsity level of each layer, run the following cell. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.keras.utils import sparsity\n",
    "from pprint import pprint\n",
    "\n",
    "model_sparsity, layer_sparsity_dict = sparsity(pruned_model)\n",
    "print(\"Model sparsity: {}\".format(model_sparsity))\n",
    "print(\"Layer sparsities:\")\n",
    "pprint(layer_sparsity_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Exporting to ONNX\n",
    "\n",
    "Now that the model is fully recalibrated, you need to export it to an ONNX format, which is the format used by the DeepSparse Engine. For Keras, exporting to ONNX is natively supported. In the cell block below, a convenience class, ModuleExporter(), is used to handle exporting.\n",
    "\n",
    "Once the model is saved as an ONNX ﬁle, it is ready to be used for inference with the DeepSparse Engine.  For saving a custom model, you can override the sample batch for ONNX graph freezing and locations to save to.\n",
    "\n",
    "#### Note:\n",
    "The `keras2onnx` is known to cause issues during the conversion; in particular, if you installed Tensorflow 2.4, you might encounter the `AttributeError: 'KerasTensor' object has no attribute 'graph'`. You might need to downgrade to Tensorflow 2.2 or 2.3 for the ONNX export to work properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.keras.utils import ModelExporter\n",
    "\n",
    "save_dir = \"keras_classification\"\n",
    "onnx_file_name = \"pruned_mnist.onnx\"\n",
    "\n",
    "exporter = ModelExporter(pruned_model, output_dir=save_dir)\n",
    "exporter.export_onnx(name=onnx_file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "\n",
    "Congratulations, you have pruned a model and exported it to ONNX for inference!  Next steps you can pursue include:\n",
    "* Pruning different models using SparseML\n",
    "* Trying different pruning and optimization recipes\n",
    "* Running your model on the DeepSparse Engine"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
